import torch
try:
    import torch_npu
except ImportError:
    print("torch_npu not found, please install it")

# Experiment directory and checkpoint tag to convert. Previously this script
# reassigned these variables several times (stage_1_new, stage_2,
# stage_2_plus_meld/step_9999, ...) and only the last assignment took effect;
# edit the values below (or pass arguments to the function) to pick a run.
exp_dir = "/home/work_nfs16/xlgeng/code/wenet_undersdand_and_speech_xlgeng_emotion_only/examples/wenetspeech/whisper/exp/two_stage_train/only_ssl_from_zero"
pt_name = "step_19999"


def convert_deepspeed_checkpoint(exp_dir=exp_dir, pt_name=pt_name):
    """Extract the model weights from a DeepSpeed checkpoint into a plain .pt file.

    Reads ``{exp_dir}/{pt_name}/mp_rank_00_model_states.pt``, takes its
    ``'module'`` entry, prints its keys, and saves that entry alone to
    ``{exp_dir}/{pt_name}.pt``.

    Args:
        exp_dir: Experiment directory containing the checkpoint sub-directory.
        pt_name: Checkpoint tag (sub-directory name), e.g. ``"step_19999"``.

    Returns:
        The extracted ``'module'`` object that was written to disk.

    Raises:
        FileNotFoundError: if the source checkpoint file does not exist.
        KeyError: if the loaded checkpoint has no ``'module'`` entry.
    """
    src_path = f"{exp_dir}/{pt_name}/mp_rank_00_model_states.pt"
    dst_path = f"{exp_dir}/{pt_name}.pt"
    # map_location='cpu' so the checkpoint loads without the original device.
    # NOTE(review): newer torch releases default torch.load to
    # weights_only=True; if this checkpoint contains non-tensor objects and
    # fails to load, pass weights_only=False (only for trusted files).
    weight_dict = torch.load(src_path, map_location='cpu')['module']
    print(weight_dict.keys())
    torch.save(weight_dict, dst_path)
    return weight_dict


if __name__ == "__main__":
    convert_deepspeed_checkpoint()
